How to run inference with a trained model in Red Chainer
Conclusion first: it looks like this
code:rb
require 'chainer'
# Define the same model that was used for training
predictor = MLP.new
# Load the snapshot saved by the trainer into the model
snapshot_filename = 'results/iris_result_20191106_224023/snapshot_epoch-30'
Chainer::Serializers::MarshalDeserializer.load_file(snapshot_filename, predictor, path: '/updater/model:main/@predictor/')
# Get the test dataset
test_dataset = Dataset.get_iris
pass_count = 0
(0...test_dataset.size).each do |i|
  variables, answer = test_dataset[i]
  # Feed the variables to the model and get the prediction
  prediction = predictor.(variables).data.argmax
  pass_count += 1 if prediction == answer
  print format("test%03d: prediction = %d, answer = %d\n", i, prediction, answer)
end
print "accuracy: #{pass_count * 100.0 / test_dataset.size}\n"
In the version over here, I changed it so the snapshot is specified via a command-line argument.
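A minimal sketch of taking the snapshot path from the command line. `snapshot_path_from` is a hypothetical helper I made up for illustration, not part of Red Chainer, and the actual script may do it differently.

```ruby
# Hypothetical helper: pick the snapshot path from the command-line arguments.
# Returns the first argument, or raises ArgumentError with a usage message.
def snapshot_path_from(argv)
  path = argv.first
  if path.nil? || path.empty?
    raise ArgumentError, 'usage: ruby inference.rb path/to/snapshot_file'
  end
  path
end

# In the real script this would be called with ARGV:
#   snapshot_filename = snapshot_path_from(ARGV)
# Here a literal array stands in for ARGV so the sketch runs on its own.
snapshot_filename = snapshot_path_from(['results/iris_result_20191106_224023/snapshot_epoch-30'])
puts snapshot_filename
```

The returned string can then be passed to `Chainer::Serializers::MarshalDeserializer.load_file` in place of the hard-coded filename.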
Everything below is my working log
[/icons/hr.icon]
code:py
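import chainer
# loaded_net (a network with the same architecture) and x_test are defined beforehand
chainer.serializers.load_npz('my_iris.net', loaded_net)
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
    y_test = loaded_net(x_test)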
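loaded_net is a network built with the Sequential class, with a snapshot loaded into it.
When the model was trained with a trainer, how do you actually run inference (at the implementation level)?
In the end, maybe it is just this: load the snapshot into the network?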
code:rb
snapshot = 'path/to/snapshot_file'
predictor = MLP.new # the MLP class defined beforehand
Chainer::Serializers::MarshalDeserializer.load_file(snapshot, predictor)
Reading Chainer::Serializers::MarshalDeserializer, I found this in a code comment: @param [object] obj Object to be deserialized. It must support serialization protocol.
So I think any object that supports the serialization protocol is fine
Which, I think, means it is OK as long as a serialize method is defined
Obviously, you can't load into an object that differs from the one that was saved, right?
Oh, it looks like even if the object differs from the one that was saved, it works out nicely if you specify the right path
I picked that up from reading the sample code in the upstream repository
Though I have no idea what path to specify, lol
The parameters of Chainer::Serializers::MarshalDeserializer.load_file do include path
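Here is a toy sketch of how I picture the serialization protocol and the path argument working together. Everything in it (`ToyDeserializer`, `ToyLayer`, `ToyModel`, the stored values) is made up for illustration. It is not Red Chainer's implementation, only a model of the idea that each object's `serialize` method asks the deserializer for its values by key, and that the path is a prefix prepended to those keys.

```ruby
# Toy stand-in for a deserializer. NOT Red Chainer's implementation.
class ToyDeserializer
  # data: flat Hash of full key => stored value
  # path: prefix prepended to every key lookup (assumed to end with '/')
  def initialize(data, path: '/')
    @data = data
    @path = path
  end

  # Descend into a child namespace, e.g. s['@l1'].
  def [](key)
    self.class.new(@data, path: "#{@path}#{key}/")
  end

  # Look up one value under the current prefix (raises KeyError if missing).
  def call(key, _current_value)
    @data.fetch("#{@path}#{key}")
  end
end

# Toy objects that "support the serialization protocol":
# all they need is a serialize method taking the (de)serializer.
class ToyLayer
  attr_reader :w, :b

  def serialize(s)
    @w = s.('@w', @w)
    @b = s.('@b', @b)
  end
end

class ToyModel
  attr_reader :l1

  def initialize
    @l1 = ToyLayer.new
  end

  def serialize(s)
    @l1.serialize(s['@l1'])
  end
end

# A flat snapshot that stores the model nested under a longer prefix.
snapshot = {
  '/updater/iteration' => 30,
  '/updater/model:main/@predictor/@l1/@w' => [0.1, 0.2],
  '/updater/model:main/@predictor/@l1/@b' => [0.0]
}

# A fresh ToyModel can be filled in by pointing the path at its subtree.
model = ToyModel.new
model.serialize(ToyDeserializer.new(snapshot, path: '/updater/model:main/@predictor/'))
p model.l1.w
```

With the default path of `/` the lookup would be for `/@l1/@w`, which the snapshot does not contain, so the load fails. That matches the intuition that the path has to point at the place where the saved object lives.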
Anyway, I ended up doing some seriously brutal debugging.
code:txt
"/updater/iterator:main/current_position", "/updater/iterator:main/epoch", "/updater/iterator:main/is_new_epoch", "/updater/iterator:main/order", "/updater/iterator:main/previous_epoch_detail", "/updater/optimizer:main/t", "/updater/optimizer:main/epoch", "/updater/optimizer:main/@predictor/@l1/@w/v", "/updater/optimizer:main/@predictor/@l1/@b/v", "/updater/optimizer:main/@predictor/@l2/@w/v", "/updater/optimizer:main/@predictor/@l2/@b/v", "/updater/optimizer:main/@predictor/@l3/@w/v", "/updater/optimizer:main/@predictor/@l3/@b/v", "/updater/model:main/@predictor/@l1/@w", "/updater/model:main/@predictor/@l1/@b", "/updater/model:main/@predictor/@l2/@w", "/updater/model:main/@predictor/@l2/@b", "/updater/model:main/@predictor/@l3/@w", "/updater/model:main/@predictor/@l3/@b", "/updater/iteration", "/stop_trigger/previous_iteration", "/stop_trigger/previous_epoch_detail", "/extension_triggers/val/previous_iteration", "/extension_triggers/val/previous_epoch_detail", "/extensions/LogReport/_trigger/previous_iteration", "/extensions/LogReport/_trigger/previous_epoch_detail", "/extensions/LogReport/_log", "/extension_triggers/LogReport/previous_iteration", "/extension_triggers/LogReport/previous_epoch_detail", "/extension_triggers/Snapshot/previous_iteration", "/extension_triggers/Snapshot/previous_epoch_detail", "/extension_triggers/PrintReport/previous_iteration", "/extension_triggers/PrintReport/previous_epoch_detail", "/extension_triggers/ProgressBar/previous_iteration", "/extension_triggers/ProgressBar/previous_epoch_detail", "/_snapshot_elapsed_time" 手元のtrainerのsnapshotをloadした時、Chainer::Serializers::MarshalDeserializer クラスのインスタンス変数 @marshal_data の keys を表示したやつ↑
It works!
code:sh
$ ruby examples/iris/inference.rb
test000: prediction = 0, answer = 0
test001: prediction = 0, answer = 0
test002: prediction = 0, answer = 0
test003: prediction = 0, answer = 0
test004: prediction = 0, answer = 0
test005: prediction = 0, answer = 0
test006: prediction = 0, answer = 0
test007: prediction = 0, answer = 0
test008: prediction = 0, answer = 0
test009: prediction = 0, answer = 0
test010: prediction = 0, answer = 0
test011: prediction = 0, answer = 0
test012: prediction = 0, answer = 0
test013: prediction = 0, answer = 0
test014: prediction = 0, answer = 0
test015: prediction = 0, answer = 0
test016: prediction = 0, answer = 0
test017: prediction = 0, answer = 0
test018: prediction = 0, answer = 0
test019: prediction = 0, answer = 0
test020: prediction = 0, answer = 0
test021: prediction = 0, answer = 0
test022: prediction = 0, answer = 0
test023: prediction = 0, answer = 0
test024: prediction = 0, answer = 0
test025: prediction = 0, answer = 0
test026: prediction = 0, answer = 0
test027: prediction = 0, answer = 0
test028: prediction = 0, answer = 0
test029: prediction = 0, answer = 0
test030: prediction = 0, answer = 0
test031: prediction = 0, answer = 0
test032: prediction = 0, answer = 0
test033: prediction = 0, answer = 0
test034: prediction = 0, answer = 0
test035: prediction = 0, answer = 0
test036: prediction = 0, answer = 0
test037: prediction = 0, answer = 0
test038: prediction = 0, answer = 0
test039: prediction = 0, answer = 0
test040: prediction = 0, answer = 0
test041: prediction = 0, answer = 0
test042: prediction = 0, answer = 0
test043: prediction = 0, answer = 0
test044: prediction = 0, answer = 0
test045: prediction = 0, answer = 0
test046: prediction = 0, answer = 0
test047: prediction = 0, answer = 0
test048: prediction = 0, answer = 0
test049: prediction = 0, answer = 0
test050: prediction = 1, answer = 1
test051: prediction = 1, answer = 1
test052: prediction = 1, answer = 1
test053: prediction = 2, answer = 1
test054: prediction = 1, answer = 1
test055: prediction = 2, answer = 1
test056: prediction = 1, answer = 1
test057: prediction = 1, answer = 1
test058: prediction = 1, answer = 1
test059: prediction = 2, answer = 1
test060: prediction = 1, answer = 1
test061: prediction = 1, answer = 1
test062: prediction = 1, answer = 1
test063: prediction = 2, answer = 1
test064: prediction = 1, answer = 1
test065: prediction = 1, answer = 1
test066: prediction = 2, answer = 1
test067: prediction = 1, answer = 1
test068: prediction = 2, answer = 1
test069: prediction = 1, answer = 1
test070: prediction = 2, answer = 1
test071: prediction = 1, answer = 1
test072: prediction = 2, answer = 1
test073: prediction = 2, answer = 1
test074: prediction = 1, answer = 1
test075: prediction = 1, answer = 1
test076: prediction = 1, answer = 1
test077: prediction = 2, answer = 1
test078: prediction = 2, answer = 1
test079: prediction = 1, answer = 1
test080: prediction = 1, answer = 1
test081: prediction = 1, answer = 1
test082: prediction = 1, answer = 1
test083: prediction = 2, answer = 1
test084: prediction = 2, answer = 1
test085: prediction = 1, answer = 1
test086: prediction = 1, answer = 1
test087: prediction = 1, answer = 1
test088: prediction = 1, answer = 1
test089: prediction = 1, answer = 1
test090: prediction = 2, answer = 1
test091: prediction = 2, answer = 1
test092: prediction = 1, answer = 1
test093: prediction = 1, answer = 1
test094: prediction = 2, answer = 1
test095: prediction = 1, answer = 1
test096: prediction = 1, answer = 1
test097: prediction = 1, answer = 1
test098: prediction = 1, answer = 1
test099: prediction = 1, answer = 1
test100: prediction = 2, answer = 2
test101: prediction = 2, answer = 2
test102: prediction = 2, answer = 2
test103: prediction = 2, answer = 2
test104: prediction = 2, answer = 2
test105: prediction = 2, answer = 2
test106: prediction = 2, answer = 2
test107: prediction = 2, answer = 2
test108: prediction = 2, answer = 2
test109: prediction = 2, answer = 2
test110: prediction = 2, answer = 2
test111: prediction = 2, answer = 2
test112: prediction = 2, answer = 2
test113: prediction = 2, answer = 2
test114: prediction = 2, answer = 2
test115: prediction = 2, answer = 2
test116: prediction = 2, answer = 2
test117: prediction = 2, answer = 2
test118: prediction = 2, answer = 2
test119: prediction = 2, answer = 2
test120: prediction = 2, answer = 2
test121: prediction = 2, answer = 2
test122: prediction = 2, answer = 2
test123: prediction = 2, answer = 2
test124: prediction = 2, answer = 2
test125: prediction = 2, answer = 2
test126: prediction = 2, answer = 2
test127: prediction = 2, answer = 2
test128: prediction = 2, answer = 2
test129: prediction = 2, answer = 2
test130: prediction = 2, answer = 2
test131: prediction = 2, answer = 2
test132: prediction = 2, answer = 2
test133: prediction = 2, answer = 2
test134: prediction = 2, answer = 2
test135: prediction = 2, answer = 2
test136: prediction = 2, answer = 2
test137: prediction = 2, answer = 2
test138: prediction = 2, answer = 2
test139: prediction = 2, answer = 2
test140: prediction = 2, answer = 2
test141: prediction = 2, answer = 2
test142: prediction = 2, answer = 2
test143: prediction = 2, answer = 2
test144: prediction = 2, answer = 2
test145: prediction = 2, answer = 2
test146: prediction = 2, answer = 2
test147: prediction = 2, answer = 2
test148: prediction = 2, answer = 2
test149: prediction = 2, answer = 2
accuracy: 89.33333333333333